% Generate Figure B1 and Table B3 (referred to as Figure A1 / Table A3 in the section comments and output file names below).

% Reset the workspace and the command window so the script starts clean.
clear
clc

% Designate working directories
% NOTE(review): the variable name `dir` shadows MATLAB's built-in dir function
% for the rest of the script; it is kept because later sections build their
% output paths from it.
dir = '.../Replication package/';
cd(dir);
addpath(genpath(strcat(dir,"Code/src")))
addpath(genpath(strcat(dir,"Datasets")))

% Open the text report. fopen returns -1 instead of raising an error when the
% file cannot be opened, so check here: otherwise the problem would only
% surface at the first fprintf to fileID, after the whole bootstrap has run.
report_path = strcat(dir,'appl_RUR_test_statistic.txt');
[fileID, fopen_msg] = fopen(report_path, 'w');
if fileID == -1
    error('Could not open %s for writing: %s', report_path, fopen_msg);
end

% Load the dataset into the workspace.
% NOTE(review): datau, dataw, p, N and w_var_names are never assigned in this
% script, so they presumably come from this MAT-file -- confirm.
load('BBCCR');

% Tuning constants
alpha_level  = 0.1;    % 1 - alpha_level is the bootstrap quantile used for the critical values
bR           = 1000;   % number of bootstrap repetitions
eps_constant = 1e-2;   % cutoff value for determining numerical zeros
pso_nmax     = 100;    % maximum number of PSO trials for non-decreasing solutions w.r.t. lambda
bnd          = 10;     % third argument of pen_Bierens_test

% Grid for penalty levels: from 1 down to 0.2 in steps of 0.1
lambdaset = 1:-0.1:0.2;

% Set seed (Mersenne Twister generator) so the random draws are reproducible
seed = 4534546;
rng(seed, 'twister')

% Penalized test on the observed data, evaluated over the whole penalty grid
[test, gammahat, nzero] = pen_Bierens_test(datau, dataw, bnd, lambdaset, eps_constant, pso_nmax);

% Reported further down as the number of selected covariates
n_nonzero = p - nzero;

% Standard normal multipliers: one column of N draws per bootstrap repetition
bR_normal = normrnd(0, 1, [N, bR]);

% Drift constants added to the bootstrap sample: 0 (used for size) and
% std(datau) (used for local power); both are scaled by 1/sqrt(N).
std_mu = std(datau);
cns = [0; std_mu];
n_cns = length(cns);
n_lambda = length(lambdaset);
cns_shift = cns / sqrt(N);   % loop-invariant, so computed once

RP_matrix = zeros(n_lambda, n_cns);
b_stat_matrix = zeros(bR, n_cns, n_lambda);   % repetition x drift x penalty level

for i_b = 1:bR
    % Bootstrap sample: datau multiplied element-wise by the i_b-th column of draws
    mu_star = datau .* bR_normal(:, i_b);

    for i_cns = 1:n_cns
        mu_star_cns = mu_star + cns_shift(i_cns);

        % Only the test statistics are stored. The second and third outputs
        % were never used, so they are discarded with ~; three outputs are
        % still requested so the callee sees the same nargout as before.
        [test_star, ~, ~] ...
            = pen_Bierens_test(mu_star_cns, dataw, bnd, lambdaset, eps_constant, pso_nmax);

        b_stat_matrix(i_b, i_cns, :) = test_star;
    end
end

% Bootstrap critical value and rejection probabilities for each penalty level
for i_lambda = 1:numel(lambdaset)
    stats_lambda = b_stat_matrix(:, :, i_lambda);   % repetitions x drift constants

    % Critical value: (1 - alpha_level) quantile of the statistics with zero drift (column 1)
    cv = quantile(stats_lambda(:, 1), (1 - alpha_level));

    % Share of repetitions exceeding cv, column by column. The dimension is
    % given explicitly so that a single repetition (bR = 1) is not averaged
    % across the columns instead.
    RP = mean(stats_lambda > cv, 1);

    RP_matrix(i_lambda, :) = RP;  % first column comprises sizes, and second column shows local powers across lambda values.
end

%%%%%%%%%%%%%%%%%%%%%%%
% Plots for Figure A1.
%%%%%%%%%%%%%%%%%%%%%%%
% Top-left panel: test statistic (left axis) and number of selected
% covariates (right axis), both plotted against -lambda.
% The figure handle is kept so that saving does not depend on whichever
% figure happens to be current, and the invisible figure is closed once it
% has been written to disk.
fig_top_left = figure('visible','off');
yyaxis left
plot(-lambdaset,test,'-','LineWidth',1.5);
ylim([1 5])
xlabel('Minus penalization parameter (-lambda)')
ylabel('Penalized test statistic')   % label typo ("statistc") corrected
title('Visualization of Test Statistics')
yyaxis right
plot(-lambdaset,n_nonzero,'--','LineWidth',1.5);
ylim([0 21])
ylabel('Number of selected covariates')
saveas(fig_top_left,strcat(dir,'Figure_A1_top_left.jpg'))
close(fig_top_left)

% Top-right panel: the level alpha_level (flat dashed line) against the
% bootstrap rejection probability under the positive drift (second column of
% RP_matrix), both plotted against -lambda.
% The figure handle is kept so that saving does not depend on whichever
% figure happens to be current, and the invisible figure is closed once it
% has been written to disk.
fig_top_right = figure('visible','off');
plot(-lambdaset,alpha_level*ones(length(lambdaset),1),'--','LineWidth',1.5,'DisplayName','size: B=0');
hold on
plot(-lambdaset,RP_matrix(:,2),'-','LineWidth',1.5,'DisplayName','power: B=std(U) > 0');
hold off
ylim([0 1])
xlabel('Minus penalization parameter (-lambda)')
ylabel('Rejection Probability')   % label typo ("Probabiity") corrected
title('Selection of penalization parameter')
legend
saveas(fig_top_right,strcat(dir,'Figure_A1_top_right.jpg'))
close(fig_top_right)

% Bottom panel: one coefficient path per covariate, plotted against -lambda.
% The figure handle is kept so that saving does not depend on whichever
% figure happens to be current, and the invisible figure is closed once it
% has been written to disk.
fig_bottom = figure('visible','off');
hold on

% Line styles cycle with the covariate index:
% mod(i_p,3) == 0 -> solid, == 1 -> dashed, == 2 -> dotted.
line_styles = {'-', '--', ':'};
for i_p = 1:p
    var_name = w_var_names{i_p};
    style_now = line_styles{mod(i_p,3) + 1};
    plot(-lambdaset,gammahat(:,i_p),style_now,'LineWidth',1.5,'DisplayName',var_name)
end

hold off
lgd = legend;
lgd.NumColumns = 3;
lgd.Location = 'northwest';

ylim([-5 5])
xlabel('Minus penalization parameter (-lambda)')
ylabel('Coefficients (gamma)')   % label typo ("gammma") corrected
title('Visualization of Coefficients')
saveas(fig_bottom,strcat(dir,'Figure_A1_bottom.jpg'))
close(fig_bottom)

% Flag the coefficients whose absolute value is below eps_constant
gammahat_zero_var = (abs(gammahat) < eps_constant);

% For every penalty level, list the covariates that are not flagged as zero,
% both in the command window and in the text report opened earlier.
for i_lambda = 1:numel(lambdaset)
    lambda_now = lambdaset(i_lambda);
    is_selected = ~gammahat_zero_var(i_lambda,:);
    selected_covariates = w_var_names(is_selected);

    % Command window output
    disp('lambda');
    disp(lambda_now);
    disp('Selected Covariates');
    disp(selected_covariates);

    % Report file output
    fprintf(fileID, ...
        'The lambda is %8.3f and the number of selected covariates is %d\n', ...
        lambda_now, (p - nzero(i_lambda)));
    fprintf(fileID, 'Selected Covariates\n');
    fprintf(fileID, '%s - ', selected_covariates{:});   % the format is reused for each name
    fprintf(fileID, '\n');
end

% The report is complete; release the file handle
fclose(fileID);

%%%%%%%%%%%%%%%%%%%%
% Table A3.
%%%%%%%%%%%%%%%%%%%%
% One row per reported penalty level (0.2 and 0.3): test statistic, number of
% selected covariates, and bootstrap p-value, i.e. the share of zero-drift
% bootstrap statistics (column 1) that exceed the observed statistic.
fprintf('lambda \t Test stat. \t No. of selected cov.s \t Bootstrap p-value\n')
for lambda_tab = [0.2 0.3]
    % Locate the grid point with a tolerance because the grid is built in
    % floating point; take the first match only so the index is a scalar.
    i_tab = find(abs(lambdaset - lambda_tab) < 1e-2, 1);
    if isempty(i_tab)
        warning('lambda = %.1f is not on the penalty grid; table row skipped.', lambda_tab);
        continue
    end

    p_value = mean(b_stat_matrix(:, 1, i_tab) > test(i_tab));

    % The stray '$' that preceded %d (it printed a literal dollar sign in
    % front of the covariate count) has been removed from the format.
    fprintf('%.1f\t%.3f\t%d\t%.3f\n', lambda_tab, test(i_tab), p - nzero(i_tab), p_value)
end